import torch
from torch import nn

a = torch.randn(10, 5, 6)
# 线性层可以矩阵乘法改变输入矩阵的最后一个维度
layer = nn.Linear(6, 12)
b = layer(a)
print(b.shape)
